import torch

# Start from an all-zero 5x5 mask of 64-bit integers (torch.long).
input_mask = torch.zeros((5, 5), dtype=torch.long)
print(input_mask)

# Switch on rows 0-3, columns 0-1. Assigning a scalar through a basic
# slice writes into input_mask in place, exactly like fill_(1) on the view.
input_mask[:4, :2] = 1
print(input_mask)
